import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Global figure styling, applied once at import time: sans-serif text
# (Helvetica as the requested face) and STIX glyphs for math text.
# ``plt.rcParams`` and ``matplotlib.rcParams`` are the same object, so a single
# ordered update covers all three settings.
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Helvetica',
    'mathtext.fontset': 'stix',
})


def Plot_Manage_Y_and_S(Y):
    """Scatter a two-dimensional sample and overlay the desired region S.

    Row 0 of ``Y`` is drawn on the x-axis and row 1 on the y-axis. The
    region S is outlined with three red boundary segments and shaded as a
    translucent red polygon. Everything is drawn on the current pyplot
    figure, which is then written to ``Y_Manage.pdf``; the figure is neither
    created nor closed here.

    Args:
        Y: 2-D array-like supporting ``Y[0, :]`` and ``Y[1, :]`` indexing
            (e.g. a NumPy array with at least two rows).
    """
    # Select row 1 explicitly. The previous ``Y[1:]`` slice only worked by
    # accident when Y had exactly two rows and failed with a size mismatch
    # as soon as Y carried any additional row.
    plt.scatter(Y[0, :], Y[1, :], s=10, marker="o",
                label=r"Original $\mathbf{Y}$ Distribution")

    # Boundary of S: a diagonal cut near the origin, then a vertical and a
    # horizontal edge running out to the plotted corner (6, 4).
    x1 = [-0.3, 0.5]
    y1 = [7/30, -0.3]
    x2 = [-0.3, -0.3]
    y2 = [7/30, 4]
    x3 = [0.5, 6]
    y3 = [-0.3, -0.3]
    plt.plot(x1, y1, color='red', lw=2)
    plt.plot(x2, y2, color='red', lw=2)
    plt.plot(x3, y3, color='red', lw=2)

    # Closed polygon through the boundary corners plus the far corner (6, 4),
    # used only for the translucent fill and the legend entry.
    x_poly = np.array([x1[0], x1[1], x3[1], 6, x2[1], x1[0]])
    y_poly = np.array([y1[0], y1[1], y3[1], 4, y2[1], y1[0]])
    plt.fill(x_poly, y_poly, 'red', alpha=0.3, label=r"Desired Region $\mathcal{S}$")

    plt.legend(fontsize=12, loc='lower left', bbox_to_anchor=(0.03, 0.03))
    plt.xlabel('Natural TPF Value', fontsize=14)
    plt.ylabel('Natural NCT Value', fontsize=14)
    plt.xlim(-6.1, 6.15)
    plt.ylim(-4.1, 4.1)
    plt.savefig('Y_Manage.pdf', bbox_inches='tight', pad_inches=0.01)


def Plot_Bermuda_Y_and_S(Y):
    """Plot a normalised histogram of ``Y`` with the desired interval shaded.

    Draws a 30-bin density histogram of ``Y`` on the current pyplot figure,
    marks the interval [0.5, 2] with red vertical lines and a translucent red
    band, and writes the figure to ``Y_Bermuda.pdf``. The figure is neither
    created nor closed here.

    Args:
        Y: sample values accepted by ``plt.hist``.
    """
    plt.hist(Y, bins=30, density=True, label=r"Original $\mathbf{Y}$ Distribution")
    plt.xlabel(r'Natural NEC ($\mathbf{Y}$) Value', fontsize=14)
    # NOTE(review): the histogram is drawn with density=True, i.e. it shows an
    # empirical density rather than a CDF -- confirm whether this label is intended.
    plt.ylabel(r'Empirical CDF for $\mathbf{Y}$', fontsize=14)

    # Fix the y-range BEFORE shading. The band height is read back from the
    # axes, and previously it was taken from the auto-scaled limit, so the
    # band could stop short of the top once the range was widened to 0.36.
    plt.ylim(0.0, 0.36)

    plt.axvline(x=2, color='red')
    plt.axvline(x=0.5, color='red')
    plt.fill_between([0.5, 2], 0, plt.ylim()[1], color='red', alpha=0.3,
                     label=r"Desired Region $\mathcal{S}$")

    plt.legend(fontsize=12, loc='upper left', bbox_to_anchor=(0.03, 0.97))
    plt.savefig('Y_Bermuda.pdf', bbox_inches='tight', pad_inches=0.01)





def PlotRes(res_list):
    """Show a per-round curve of the values in ``res_list``.

    Each value is drawn as a small marker and the markers are joined by a
    line; the x-axis counts rounds starting at 1. The figure is displayed
    with ``plt.show()``.

    Args:
        res_list: sequence of values, one per round.
    """
    rounds = [i + 1 for i in range(len(res_list))]

    plt.scatter(rounds, res_list, s=5)
    plt.plot(rounds, res_list)

    plt.ylabel('MSE of Parameters')
    plt.xlabel('Round')
    # Neither artist above carries a label, so this legend has no entries.
    plt.legend()
    plt.show()


def PlotSinRes(res_list, node):
    """Show a per-round curve of ``res_list``, labelled with ``node``.

    Each value is drawn as a small marker and the markers are joined by a
    line whose legend entry is ``node``; the x-axis counts rounds starting
    at 1. The figure is displayed with ``plt.show()``.

    Args:
        res_list: sequence of values, one per round.
        node: label used for the line in the legend.
    """
    rounds = [i + 1 for i in range(len(res_list))]

    plt.scatter(rounds, res_list, s=5)
    plt.plot(rounds, res_list, label=node)

    plt.ylabel('MSE of Parameters')
    plt.xlabel('Round')
    plt.legend()
    plt.show()



def find_all_paths_with_costs(graph, start, end, path=None, cost=1.0):
    """Enumerate every simple path from ``start`` to ``end`` with its cost.

    The cost of a path is ``cost`` multiplied by the last element of the
    weight sequence stored on each edge along the path.

    Args:
        graph: mapping ``node -> {neighbour: weights}`` where ``weights`` is
            an indexable sequence; only ``weights[-1]`` is used.
        start: node to start from.
        end: node to reach.
        path: nodes already visited before ``start`` (used by the recursion);
            ``None`` means an empty prefix. The argument is never mutated.
        cost: cost accumulated before ``start``; defaults to the
            multiplicative identity.

    Returns:
        A list of ``(path, cost)`` tuples, one per simple path, in depth-first
        order of the graph's neighbour iteration. Empty if ``end`` cannot be
        reached from ``start``.
    """
    # A ``None`` sentinel avoids sharing one default list across calls; a new
    # list is built at every level so sibling branches never see each other.
    path = [start] if path is None else [*path, start]

    if start == end:
        return [(path, cost)]
    if start not in graph:
        return []

    paths = []
    for node in graph[start]:
        # Skip nodes already on the current path so cycles cannot loop forever.
        if node not in path:
            paths.extend(
                find_all_paths_with_costs(
                    graph, node, end, path, cost * graph[start][node][-1]
                )
            )
    return paths

def find_total_costs_to_node(graph, end_node):
    """Sum the path costs from every other graph node to ``end_node``.

    For each key of ``graph`` other than ``end_node``, all simple paths to
    ``end_node`` are enumerated with ``find_all_paths_with_costs`` and their
    costs are added up. A node with no path to ``end_node`` maps to ``0``.

    Args:
        graph: mapping ``node -> {neighbour: weights}`` as expected by
            ``find_all_paths_with_costs``.
        end_node: the target node.

    Returns:
        Dict mapping each source node (in ``graph`` iteration order) to its
        summed path cost.
    """
    return {
        source: sum(
            weight
            for _, weight in find_all_paths_with_costs(graph, source, end_node)
        )
        for source in graph
        if source != end_node
    }


if __name__ == "__main__":
    # Small demo: a five-node directed graph whose edges carry a pair of
    # weights; only the last element of each pair enters the path cost.
    graph = {
        'A': {
            'B': (0.1, 0.1),
        },
        'B': {
            'C': (0.2, 0.2),
        },
        'C': {
            'D': (0.3, 0.3),
            'E': (2.0, 2.0),
        },
        'D': {
            'E': (0.4, 0.4),
        },
        'E': {},
    }
    end_node = 'E'

    # Print the summed path cost from every other node to the end node.
    paths_to_node = find_total_costs_to_node(graph, end_node)
    print(paths_to_node)